{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rcd_dataSet import RCD_DataSet\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataSet: ASSIST_0910\n"
     ]
    }
   ],
   "source": [
    "# ----------基本参数--------------\n",
    "datadir = '../../'\n",
    "concept_graph_dir = './graph/'\n",
    "dataSet_list = ('FrcSub', 'Math1', 'Math2', 'ASSIST_0910', 'ASSIST_2017', 'MAT_2016', 'JUNYI')\n",
    "\n",
    "dataSet_idx = 3\n",
    "test_ratio = 0.8\n",
    "batch_size = 64\n",
    "learn_rate = 3e-2\n",
    "emb_dim = 16\n",
    "\n",
    "data_set_name = dataSet_list[dataSet_idx]\n",
    "epochs = 8\n",
    "device = 'cuda'\n",
    "# ----------基本参数--------------\n",
    "\n",
    "dataSet = RCD_DataSet(datadir, concept_graph_dir, data_set_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HeteroData(\n",
       "  \u001b[1mstu\u001b[0m={ num_nodes=2392 },\n",
       "  \u001b[1mitem\u001b[0m={ num_nodes=17657 },\n",
       "  \u001b[1mconc\u001b[0m={ num_nodes=116 },\n",
       "  \u001b[1m(stu, answer, item)\u001b[0m={ edge_index=[2, 265900] },\n",
       "  \u001b[1m(item, contain, conc)\u001b[0m={ edge_index=[2, 17657] },\n",
       "  \u001b[1m(conc, relevant, conc)\u001b[0m={ edge_index=[2, 3364] }\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataSet.all_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "97"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataSet.item_conc_df['concept'].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "def get_item_concept_df(item) -> pd.DataFrame:\n",
    "    item = item[~item.index.duplicated()]  # 去除重复项\n",
    "    item_list, concept_list = [], []\n",
    "    for idx in item.index:\n",
    "        now_concept_list = eval(item.loc[idx, 'knowledge_code'])\n",
    "        item_list.extend([idx] * len(now_concept_list))\n",
    "        concept_list.extend(now_concept_list)\n",
    "    return pd.DataFrame({'item': item_list, 'concept': concept_list}).astype('int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp=get_item_concept_df(dataSet.item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "97"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(tmp['concept'].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_item_concept_df2(item) -> pd.DataFrame:\n",
    "    item = item[~item.index.duplicated()]  # 去除重复项\n",
    "    item_list, concept_list = [], []\n",
    "    for idx in item.index:\n",
    "        now_concept_list = eval(item.loc[idx, 'knowledge_code'])\n",
    "        item_list.extend([idx] * len(now_concept_list))\n",
    "        concept_list.extend(now_concept_list)\n",
    "    return pd.DataFrame({'item': item_list, 'concept': concept_list}).astype('int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>knowledge_code</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>item_id</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[10]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[1]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[10]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[1]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[10]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17653</th>\n",
       "      <td>[94]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17654</th>\n",
       "      <td>[94]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17655</th>\n",
       "      <td>[94]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17656</th>\n",
       "      <td>[94]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17657</th>\n",
       "      <td>[94]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>17657 rows × 1 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        knowledge_code\n",
       "item_id               \n",
       "1                 [10]\n",
       "2                  [1]\n",
       "3                 [10]\n",
       "4                  [1]\n",
       "5                 [10]\n",
       "...                ...\n",
       "17653             [94]\n",
       "17654             [94]\n",
       "17655             [94]\n",
       "17656             [94]\n",
       "17657             [94]\n",
       "\n",
       "[17657 rows x 1 columns]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataSet.item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_ratio: 0.2\n"
     ]
    }
   ],
   "source": [
    "train, test = dataSet.get_train_test(dataSet.record, test_ratio=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Graphing:]: 100%|██████████| 9/9 [00:00<00:00, 29.39it/s]\n"
     ]
    }
   ],
   "source": [
    "data_list=dataSet.get_data_loader(train,batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8576"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HeteroData(\n",
       "  triple_index=[3, 2855],\n",
       "  mean_index=[2855],\n",
       "  label=[1024],\n",
       "  \u001b[1mstu\u001b[0m={\n",
       "    num_nodes=64,\n",
       "    pos=[64]\n",
       "  },\n",
       "  \u001b[1mitem\u001b[0m={\n",
       "    num_nodes=20,\n",
       "    pos=[20]\n",
       "  },\n",
       "  \u001b[1mconc\u001b[0m={\n",
       "    num_nodes=8,\n",
       "    pos=[8]\n",
       "  },\n",
       "  \u001b[1m(stu, answer, item)\u001b[0m={ edge_index=[2, 1024] },\n",
       "  \u001b[1m(item, contain, conc)\u001b[0m={ edge_index=[2, 56] },\n",
       "  \u001b[1m(stu, master, conc)\u001b[0m={ edge_index=[2, 512] }\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_list[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of paramerters in networks is 12721  \n"
     ]
    }
   ],
   "source": [
    "from models import RCD\n",
    "\n",
    "model=RCD(dataSet.stu_num,dataSet.prob_num,dataSet.concept_num,emb_dim=emb_dim,learn_rate=learn_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Epoch:1]: 100%|██████████| 9/9 [00:00<00:00, 20.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t1 epoch loss = 0.7065359817610847\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Epoch:2]: 100%|██████████| 9/9 [00:00<00:00, 26.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t2 epoch loss = 0.6982975900173187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HeteroData(\n",
       "  \u001b[1mstu\u001b[0m={\n",
       "    num_nodes=536,\n",
       "    x=[536, 16]\n",
       "  },\n",
       "  \u001b[1mitem\u001b[0m={\n",
       "    num_nodes=20,\n",
       "    x=[20, 16]\n",
       "  },\n",
       "  \u001b[1mconc\u001b[0m={\n",
       "    num_nodes=8,\n",
       "    x=[8, 16]\n",
       "  },\n",
       "  \u001b[1m(stu, answer, item)\u001b[0m={ edge_index=[2, 10720] },\n",
       "  \u001b[1m(item, contain, conc)\u001b[0m={ edge_index=[2, 56] },\n",
       "  \u001b[1m(conc, relevant, conc)\u001b[0m={ edge_index=[2, 53] }\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(dataSet.all_graph,data_list,epochs=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataSet: FrcSub\n",
      "test_ratio: 0.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[split record:]: 100%|██████████| 536/536 [00:00<00:00, 3656.05it/s]\n",
      "[Graphing:]: 100%|██████████| 536/536 [00:02<00:00, 247.02it/s]\n"
     ]
    }
   ],
   "source": [
    "from rcd_dataSet import RCD_DataSet\n",
    "\n",
    "datadir = '../../'\n",
    "concept_graph_dir = './graph/'\n",
    "dataSet_list = ('FrcSub','ASSIST_0910', 'ASSIST_2017', 'JUNYI', 'MathEC', 'KDDCUP')\n",
    "\n",
    "dataSet_idx = 0\n",
    "test_ratio = 0.2\n",
    "batch_size = 32\n",
    "learn_rate = 3e-3\n",
    "\n",
    "data_set_name = dataSet_list[dataSet_idx]\n",
    "epochs = 80\n",
    "# ----------基本参数--------------\n",
    "\n",
    "dataSet = RCD_DataSet(datadir, concept_graph_dir, data_set_name)\n",
    "\n",
    "train, test = dataSet.get_train_test(dataSet.record, test_ratio=test_ratio)\n",
    "# data_list = dataSet.get_data_loader(train, batch_size=batch_size)\n",
    "test_data_list = dataSet.get_data_loader(test, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in test_data_list:\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RCDDataBatch(\n",
       "  stu_track=[355],\n",
       "  item_track=[355],\n",
       "  conc_track=[355],\n",
       "  mean_track=[355],\n",
       "  label=[128],\n",
       "  \u001b[1mstu\u001b[0m={\n",
       "    num_nodes=32,\n",
       "    pos=[32],\n",
       "    batch=[32],\n",
       "    ptr=[33]\n",
       "  },\n",
       "  \u001b[1mitem\u001b[0m={\n",
       "    num_nodes=128,\n",
       "    pos=[128],\n",
       "    batch=[128],\n",
       "    ptr=[33]\n",
       "  },\n",
       "  \u001b[1mconc\u001b[0m={\n",
       "    num_nodes=184,\n",
       "    pos=[184],\n",
       "    batch=[184],\n",
       "    ptr=[33]\n",
       "  },\n",
       "  \u001b[1m(stu, answer, item)\u001b[0m={ edge_index=[2, 128] },\n",
       "  \u001b[1m(item, contain, conc)\u001b[0m={ edge_index=[2, 355] },\n",
       "  \u001b[1m(stu, master, conc)\u001b[0m={ edge_index=[2, 184] }\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([  0,   1,   2,   3,   3,   4,   3,   5,   3,   6,   7,   8,   9,  10,\n",
       "          7,   8,   9,   6,  10,   8,  11,   8,  12,  13,  14,  15,  13,  14,\n",
       "         14,  15,  14,  16,  17,  18,  19,  20,  16,  20,  17,  20,  17,  21,\n",
       "         22,  23,  24,  25,  22,  26,  25,  27,  25,  27,  21,  25,  28,  29,\n",
       "         30,  31,  32,  28,  33,  29,  31,  28,  34,  29,  31,  28,  34,  31,\n",
       "         35,  36,  37,  38,  39,  36,  40,  38,  41,  38,  38,  42,  43,  44,\n",
       "         45,  46,  45,  46,  47,  45,  47,  43,  48,  49,  45,  50,  51,  52,\n",
       "         53,  52,  50,  52,  54,  50,  55,  51,  52,  56,  57,  56,  57,  58,\n",
       "         59,  56,  57,  60,  59,  56,  61,  62,  63,  64,  64,  65,  64,  61,\n",
       "         63,  66,  64,  67,  68,  69,  67,  70,  69,  71,  67,  69,  70,  69,\n",
       "         72,  72,  73,  74,  72,  74,  75,  76,  72,  77,  78,  79,  80,  81,\n",
       "         79,  82,  79,  80,  83,  83,  84,  85,  86,  87,  85,  83,  88,  84,\n",
       "         85,  89,  90,  91,  92,  89,  90,  93,  91,  89,  94,  93,  91,  89,\n",
       "         93,  95,  91,  96,  97,  98,  99,  96,  97,  99,  96, 100,  97,  96,\n",
       "        101, 102, 103, 102, 104, 105, 103, 101, 103, 102, 106, 103, 107, 108,\n",
       "        109, 110, 108, 111, 109, 110, 110, 112, 108, 113, 114, 115, 116, 117,\n",
       "        114, 115, 116, 117, 115, 116, 117, 115, 118, 119, 120, 121, 122, 123,\n",
       "        122, 119, 122, 124, 125, 126, 127, 125, 126, 128, 129, 127, 130, 131,\n",
       "        126, 127, 131, 126, 128, 132, 133, 132, 134, 133, 135, 133, 132, 134,\n",
       "        133, 136, 137, 138, 139, 138, 139, 138, 139, 140, 138, 141, 142, 143,\n",
       "        144, 144, 145, 146, 144, 141, 146, 144, 145, 147, 148, 149, 150, 151,\n",
       "        147, 149, 152, 150, 147, 147, 153, 154, 153, 154, 155, 156, 157, 154,\n",
       "        155, 158, 156, 154, 159, 160, 161, 162, 159, 159, 162, 163, 164, 165,\n",
       "        166, 164, 167, 163, 164, 168, 169, 170, 169, 170, 171, 169, 170, 172,\n",
       "        173, 174, 172, 173, 175, 176, 173, 175, 173, 177, 178, 179, 180, 181,\n",
       "        182, 179, 179, 179, 183])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.conc_track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RCDData(\n",
       "  \u001b[1mstu\u001b[0m={ num_nodes=536 },\n",
       "  \u001b[1mitem\u001b[0m={ num_nodes=20 },\n",
       "  \u001b[1mconc\u001b[0m={ num_nodes=8 },\n",
       "  \u001b[1m(stu, answer, item)\u001b[0m={ edge_index=[2, 10720] },\n",
       "  \u001b[1m(item, contain, conc)\u001b[0m={ edge_index=[2, 56] },\n",
       "  \u001b[1m(conc, relevant, conc)\u001b[0m={ edge_index=[2, 53] }\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataSet.all_graph"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "527a93331b4b1a8345148922acc34427fb7591433d63b66d32040b6fbbc6d593"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
